/*************************************************************************
 *
 * Hitchhiker's Guide to the IBM PE 
 * Matrix Inversion Program - First parallel implementation
 * Chapter 2 - The Answer is 42
 *
 * To compile:
 * mpcc -g -o inverse_parallel inverse_parallel.c
 *
 *************************************************************************/

#include<stdlib.h>
#include<stdio.h>
#include<assert.h>
#include<errno.h>
#include<mpi.h>

/* Determinant of the part of `matrix` that remains after excluding the   */
/* rows used_rows[0..depth-1] and the columns used_cols[0..depth-1].      */
float determinant(float **matrix,int size, int * used_rows, int * used_cols, int depth);
/* Signed cofactor of element (row,col): the determinant of `matrix` with */
/* that row and column excluded, negated when row+col is odd.             */
float coefficient(float **matrix,int size, int row, int col);
/* Write a rows-by-cols matrix to fptr, one matrix row per output line.   */
void print_matrix(FILE * fptr,float ** mat,int rows, int cols);

/* Static input matrix; main() copies it into a heap-allocated working    */
/* matrix before computing the determinant and the inverse.               */
float test_data[8][8] =  { 
			   {4.0, 2.0, 4.0, 5.0, 4.0, -2.0, 4.0, 5.0}, 
			   {4.0, 2.0, 4.0, 5.0, 3.0, 9.0, 12.0, 1.0 }, 
			   {3.0, 9.0, -13.0, 15.0, 3.0, 9.0, 12.0, 15.0}, 
			   {3.0, 9.0, 12.0, 15.0, 4.0, 2.0, 7.0, 5.0 }, 
			   {2.0, 4.0, -11.0, 10.0, 2.0, 4.0, 11.0, 10.0 },
			   {2.0, 4.0, 11.0, 10.0, 3.0, -5.0, 12.0, 15.0 },
			   {1.0, -2.0, 4.0, 10.0, 3.0, 9.0, -12.0, 15.0 } ,
			   {1.0, 2.0, 4.0, 10.0, 2.0, -4.0, -11.0, 10.0 } ,
};
/* Order of the matrix processed by main(); it is used as the loop bound  */
/* when copying test_data, so it must not exceed test_data's dimensions.  */
#define ROWS 8
/* me    - rank of this task in MPI_COMM_WORLD (set in main)              */
/* tasks - number of tasks in MPI_COMM_WORLD (set in main)                */
/* tag   - message tag used for the row transfers in main                 */
int me, tasks, tag=0;

int main(int argc, char **argv)
{
  float **matrix;
  float **inverse;
  int rows,i,j;
  float determ;
  int * used_rows, * used_cols;

  MPI_Status status[ROWS];              /* Status of messages */
  MPI_Request req[ROWS];		/* Message IDs */

  MPI_Init(&argc,&argv);       		/* Initialize MPI */
  MPI_Comm_size(MPI_COMM_WORLD,&tasks);	/* How many parallel tasks are there?*/
  MPI_Comm_rank(MPI_COMM_WORLD,&me);	/* Who am I? */

  rows = ROWS;

  /* We need exactly one task for each row of the matrix plus one task */
  /* to act as coordinator.  If we don't have this, the last task      */
  /* reports the error (so everybody doesn't put out the same message  */
  if(tasks!=rows+1)
    {
      if(me==tasks-1)
	fprintf(stderr,"%d tasks required for this demo (one more than the number of rows in matrix\n",rows+1);
      exit(-1);
    }

  /* Allocate markers to record rows and columns to be skipped */
  /* during determinant calculation                            */
  used_rows = (int *)    malloc(rows*sizeof(*used_rows));
  used_cols = (int *)    malloc(rows*sizeof(*used_cols));
  
  /* Allocate working copy of matrix and initialize it from static copy */
  matrix = (float **) malloc(rows*sizeof(*matrix));
  for(i=0;i<rows;i++)
    {
      matrix[i] = (float *) malloc(rows*sizeof(**matrix));
      for(j=0;j<rows;j++)
	matrix[i][j] = test_data[i][j];
    }

  /* Everyone computes the determinant (to avoid message transmission */
  determ=determinant(matrix,rows,used_rows,used_cols,0);

  if(me==tasks-1)
    { /* The last task acts as coordinator */
      inverse = (float **) malloc(rows*sizeof(*inverse));
      for(i=0;i<rows;i++)
	{
	  inverse[i] = (float *) malloc(rows*sizeof(**inverse));
	}
      /* Print the determinant */
      printf("The determinant of\n\n");
      print_matrix(stdout,matrix,rows,rows);
      printf("\nis %f\n",determ);
      /* Collect the rows of the inverse matrix from the other tasks */
      /* First, post a receive from each task into the appropriate row */
      for(i=0;i<rows;i++)
	{
	  MPI_Irecv(inverse[i],rows,MPI_REAL,i,tag,MPI_COMM_WORLD,&(req[i]));
	}
      /* Then wait for all the receives to complete */
      MPI_Waitall(rows,req,status);
      printf("The inverse is\n\n");
      print_matrix(stdout,inverse,rows,rows);
    }
  else
    { /* All the other tasks compute a row of the inverse matrix */
      int dest = tasks-1;
      float *one_row;
      int size = rows*sizeof(*one_row);

      one_row = (float *) malloc(size);
      for(j=0;j<rows;j++)
	{
	  one_row[j] = coefficient(matrix,rows,j,me)/determ;
	}
      /* Send the row back to the coordinator */
      MPI_Send(one_row,rows,MPI_REAL,dest,tag,MPI_COMM_WORLD);      
    }

  /* Wait for all parallel tasks to get here, then quit */
  MPI_Barrier(MPI_COMM_WORLD);
  MPI_Finalize(); 

}

/* Return 1 when idx appears in used[0..count-1], otherwise 0. */
static int is_marked(const int *used, int count, int idx)
{
  int k;

  for(k=0;k<count;k++)
    {
      if(used[k]==idx) return 1;
    }
  return 0;
}

/* Return the smallest index in [start,size) that is not listed in */
/* used[0..count-1], or size when every such index is listed.      */
static int first_unmarked(const int *used, int count, int start, int size)
{
  int idx;

  for(idx=start;idx<size;idx++)
    {
      if(!is_marked(used,count,idx)) break;
    }
  return idx;
}

/* Determinant of a minor of a square matrix, by recursive cofactor     */
/* expansion along the first remaining row.                             */
/*                                                                      */
/*   matrix    - size x size matrix stored as an array of row pointers  */
/*   size      - order of the full matrix                               */
/*   used_rows - rows already removed; entries [0..depth-1] are read,   */
/*               entries from [depth] onward are overwritten as scratch */
/*   used_cols - columns already removed, used the same way             */
/*   depth     - number of rows/columns already removed (0 for the      */
/*               determinant of the whole matrix)                       */
/*                                                                      */
/* Both marker arrays must have room for at least size-2 entries.       */
/* Returns the determinant of the (size-depth) x (size-depth) minor; an */
/* empty minor (depth == size) yields 1, the usual convention.          */
/* The work grows factorially with size-depth.                          */
float determinant(float **matrix,int size, int * used_rows, int * used_cols, int depth)
{
  int col1, col2, row1, row2;
  float total = 0;
  int sign = 1;

  assert(depth>=0 && depth<=size);

  /* Nothing left: the determinant of an empty matrix is 1 */
  if(depth==size)
    return 1.0f;

  /* Find the first unused row */
  row1 = first_unmarked(used_rows,depth,0,size);
  assert(row1<size);

  if(depth==(size-1))
    {
      /* Only one row/column is left: the minor is that single element */
      col1 = first_unmarked(used_cols,depth,0,size);
      assert(col1<size);
      return matrix[row1][col1];
    }

  if(depth==(size-2))
    {
      /* There are only 2 unused rows/columns left */
      row2 = first_unmarked(used_rows,depth,row1+1,size);
      assert(row2<size);

      col1 = first_unmarked(used_cols,depth,0,size);
      assert(col1<size);

      col2 = first_unmarked(used_cols,depth,col1+1,size);
      assert(col2<size);

      /* Determinant = m11*m22-m12*m21 */
      return matrix[row1][col1]*matrix[row2][col2]-matrix[row1][col2]*matrix[row2][col1];
    }

  /* There are more than 2 rows/columns in the matrix being processed  */
  /* Compute the determinant as the sum of the product of each element */
  /* in the first row and the determinant of the matrix with its row   */
  /* and column removed                                                */
  used_rows[depth] = row1;
  for(col1=0;col1<size;col1++)
    {
      if(is_marked(used_cols,depth,col1))  /* This column is used -- skip it */
        continue;
      used_cols[depth] = col1;
      total += sign*matrix[row1][col1]*determinant(matrix,size,used_rows,used_cols,depth+1);
      /* The sign alternates per remaining column, not per absolute */
      /* column index, so it only flips for columns that were used. */
      sign = -sign;
    }
  return total;
}

/* Write a matrix to a stream as text.                                  */
/*                                                                      */
/*   fptr - output stream                                               */
/*   mat  - matrix stored as an array of row pointers                   */
/*   rows - number of rows to print                                     */
/*   cols - number of elements to print from each row                   */
/*                                                                      */
/* Each element is written with the "%10.4f " format and every matrix   */
/* row is terminated by a newline.                                      */
void print_matrix(FILE * fptr,float ** mat,int rows, int cols)
{
  int r = 0;

  while(r<rows)
    {
      int c = 0;

      while(c<cols)
        {
          fprintf(fptr,"%10.4f ",mat[r][c]);
          c++;
        }
      fprintf(fptr,"\n");
      r++;
    }
}

/* Signed cofactor of one matrix element.                               */
/*                                                                      */
/*   matrix - size x size matrix stored as an array of row pointers     */
/*   size   - order of the matrix (must be positive)                    */
/*   row    - row of the element                                        */
/*   col    - column of the element                                     */
/*                                                                      */
/* Returns the determinant of the matrix with `row` and `col` removed,  */
/* negated when row+col is odd.  The scratch marker arrays handed to    */
/* determinant() are allocated and released here; the process is        */
/* aborted if they cannot be allocated.                                 */
float coefficient(float **matrix,int size, int row, int col)
{
  float coef;
  int * ur, *uc;

  assert(size>0);

  /* One int marker per row/column -- size by the element type */
  ur = malloc(size*sizeof(*ur));
  uc = malloc(size*sizeof(*uc));
  if(ur==NULL || uc==NULL)
    {
      fprintf(stderr,"coefficient: out of memory\n");
      free(ur);
      free(uc);
      abort();
    }

  /* Mark the element's own row and column as already removed (depth 1) */
  ur[0]=row;
  uc[0]=col;
  coef = (((row+col)%2)?-1:1)*determinant(matrix,size,ur,uc,1);

  free(ur);
  free(uc);
  return coef;
}

